[DAGMM] DAGMM: for arrhythmia data set

Author

kione kim

Published

October 19, 2023

Deep Autoencoding Gaussian Mixture Model for Arrhythmia dataset

### imports
import torch
from torch import nn
import numpy as np
import pandas as pd
import argparse
import sys
### data file
file_path = 'C:\\Users\\UOS\\Desktop\\연구\\5. 데이터\\data\\arrhythmia\\arrhythmia.data'

# Load the UCI arrhythmia dataset; missing values are encoded as '?' in the raw file.
df = pd.read_csv(file_path, header=None)
# NOTE(review): missing values are replaced with 0 rather than imputed — confirm intended.
df = df.replace('?', 0)
df = df.astype('float64')

data_array = df.values
# torch.autograd.Variable is a deprecated no-op wrapper since PyTorch 0.4;
# a plain tensor behaves identically.
data_array = torch.from_numpy(data_array).float()
data_array.shape
# torch.Size([452, 280])   <- notebook output (commented out: it was a stray
# expression left from the exported cell output)
# Command-line configuration. The hidden-dim arguments are comma-separated
# strings that are expanded into int lists below.
parser = argparse.ArgumentParser(description='parser for argparse test')

parser.add_argument('--input_dim', type=int, default=data_array.shape[-1])
parser.add_argument('--enc_hidden_dim', type=str, default='10,2')
parser.add_argument('--dec_hidden_dim', type=str, default='10')
parser.add_argument('--est_hidden_dim', type=str, default='4, 10, 2')
# BUG FIX: was `action='store_true', default=0.5`, which silently turned the
# dropout rate into the boolean True whenever the flag was passed. Expose it
# as a float option with the same default instead.
parser.add_argument('--dropout', type=float, default=0.5)
parser.add_argument('--learning_rate', type=float, default=0.001)
parser.add_argument('--num_epoch', type=int, default=10)

# Drop Jupyter's injected argv so parse_args() does not choke on kernel flags.
if 'ipykernel_launcher' in sys.argv[0]:
    sys.argv = [sys.argv[0]]

args = parser.parse_args()

# Expand 'a,b,c' strings into [a, b, c]; int() tolerates the stray spaces in
# the est_hidden_dim default ('4, 10, 2').
# Encoder runs from input_dim down to the latent dimension.
args.enc_hidden_dim_list = [args.input_dim] + [int(d) for d in args.enc_hidden_dim.split(',')]
# Decoder mirrors the encoder: latent dimension first, input_dim last.
args.dec_hidden_dim_list = ([args.enc_hidden_dim_list[-1]]
                            + [int(d) for d in args.dec_hidden_dim.split(',')]
                            + [args.input_dim])
args.est_hidden_dim_list = [int(d) for d in args.est_hidden_dim.split(',')]

# Resulting namespace for the defaults:
# Namespace(input_dim=280, enc_hidden_dim='10,2', dec_hidden_dim='10',
#           est_hidden_dim='4, 10, 2', dropout=0.5, learning_rate=0.001,
#           num_epoch=10, enc_hidden_dim_list=[280, 10, 2],
#           dec_hidden_dim_list=[2, 10, 280], est_hidden_dim_list=[4, 10, 2])
### compression network
class midlayer(nn.Module):
    """A single fully-connected layer followed by a Tanh activation."""

    def __init__(self, input_dim, hidden_dim):
        super(midlayer, self).__init__()
        self.fc_layer = nn.Linear(input_dim, hidden_dim)
        self.activation = nn.Tanh()

    def forward(self, input):
        return self.activation(self.fc_layer(input))


class Encoder(nn.Module):
    """MLP encoder: Tanh-activated hidden layers, linear final projection.

    `hidden_dim_list` holds every layer width, input first and latent last,
    e.g. [280, 10, 2] -> midlayer(280, 10), Linear(10, 2).
    """

    def __init__(self, hidden_dim_list):
        super(Encoder, self).__init__()

        # All but the final pair get a Tanh via midlayer; the last mapping is
        # purely linear so the latent code is unbounded.
        layer_list = [midlayer(hidden_dim_list[k], hidden_dim_list[k + 1])
                      for k in range(len(hidden_dim_list) - 2)]
        # BUG FIX: the original indexed with the loop variable `i` leaked out
        # of the for-loop, which is undefined (NameError) whenever
        # hidden_dim_list has exactly two entries; index from the end instead.
        layer_list.append(nn.Linear(hidden_dim_list[-2], hidden_dim_list[-1]))
        self.layer = nn.Sequential(*layer_list)

    def forward(self, input):
        return self.layer(input)
    
class Decoder(nn.Module):
    """MLP decoder: every layer, including the last, is Linear + Tanh.

    `hidden_dim_list` holds every layer width, latent first and output last,
    e.g. [2, 10, 280] -> midlayer(2, 10), midlayer(10, 280).
    """

    def __init__(self, hidden_dim_list):
        super(Decoder, self).__init__()

        layer_list = [midlayer(hidden_dim_list[k], hidden_dim_list[k + 1])
                      for k in range(len(hidden_dim_list) - 2)]
        # BUG FIX: the final layer used the loop variable `i` leaked out of
        # the for-loop, raising NameError whenever hidden_dim_list had exactly
        # two entries; index from the end instead.
        layer_list.append(midlayer(hidden_dim_list[-2], hidden_dim_list[-1]))
        self.layer = nn.Sequential(*layer_list)

    def forward(self, input):
        # NOTE(review): the trailing Tanh bounds reconstructions to (-1, 1)
        # while the raw inputs are not normalised to that range — confirm.
        return self.layer(input)

class CompressionNet(nn.Module):
    """Autoencoder half of DAGMM: encoder + decoder + MSE reconstruction loss."""

    def __init__(self, enc_hidden_dim_list, dec_hidden_dim_list):
        super().__init__()
        self.encoder = Encoder(enc_hidden_dim_list)
        self.decoder = Decoder(dec_hidden_dim_list)

        self._reconstruction_loss = nn.MSELoss()

    def forward(self, input):
        """Full reconstruction: decode(encode(input))."""
        out = self.encoder(input)
        out = self.decoder(out)
        return out

    def encode(self, input):
        return self.encoder(input)

    def decode(self, input):
        return self.decoder(input)

    def reconstruction_loss(self, input, input_target):
        """MSE between the reconstruction of `input` and `input_target`."""
        target_hat = self(input)
        return self._reconstruction_loss(target_hat, input_target)

    # Backward-compatible alias: the method was originally published under
    # this misspelled name, so existing callers keep working.
    reconstuction_loss = reconstruction_loss
### reconstructed error
# Small constant used to keep denominators away from zero.
eps = torch.autograd.Variable(torch.FloatTensor([1.e-8]), requires_grad=False)

def relative_euclidean_distance(x1, x2, eps=eps):
    """Row-wise ||x1 - x2|| / ||x1||, with the denominator clamped by eps."""
    numerator = torch.norm(x1 - x2, p=2, dim=1)
    denominator = torch.norm(x1, p=2, dim=1)
    return numerator / torch.max(denominator, eps)

def cosine_similarity(x1, x2, eps=eps):
    """Row-wise cosine similarity of x1 and x2, denominator clamped by eps."""
    inner = torch.sum(x1 * x2, dim=1)
    norm_product = torch.norm(x1, p=2, dim=1) * torch.norm(x2, p=2, dim=1)
    return inner / torch.max(norm_product, eps)
### estimation network
class Estimation(nn.Module):
    """Estimation network: maps latent vectors to soft mixture memberships.

    `est_hidden_dim_list` holds every layer width; the final width is the
    number of mixture components, normalised with a row-wise Softmax.
    """

    def __init__(self, est_hidden_dim_list, dropout=0.5):
        super().__init__()

        layer_list = [midlayer(est_hidden_dim_list[k], est_hidden_dim_list[k + 1])
                      for k in range(len(est_hidden_dim_list) - 2)]
        # The dropout rate was hard-coded to 0.5; it is now a keyword
        # parameter with the same default, so existing callers are unaffected.
        layer_list.append(nn.Dropout(p=dropout))
        layer_list.append(nn.Linear(est_hidden_dim_list[-2], est_hidden_dim_list[-1]))
        # FIX: give Softmax an explicit dim (dim=1 is what the legacy implicit
        # rule picked for 2-D input); the bare nn.Softmax() call is deprecated
        # and produced the UserWarning captured in the notebook output.
        layer_list.append(nn.Softmax(dim=1))
        self.net = nn.Sequential(*layer_list)

    def forward(self, input):
        return self.net(input)
### Mixture
class Mixture(nn.Module):
    """One Gaussian component of the GMM, with EM-style parameter updates.

    Phi (mixing weight), mu (mean) and Sigma (covariance) are registered as
    non-trainable nn.Parameters: they are overwritten in-place by
    `_update_parameters`, never by gradient descent.
    """
    def __init__(self, latent_dimension):
        super().__init__()
        self.latent_dimension = latent_dimension

        # Phi: scalar mixing weight, initialised uniformly in [0, 1).
        self.Phi    = np.random.random([1])
        self.Phi    = torch.from_numpy(self.Phi).float()
        self.Phi    = nn.Parameter(self.Phi, requires_grad = False)

        # mu: component mean, initialised uniformly in [-0.5, 1.5).
        self.mu     = 2.*np.random.random([latent_dimension]) - 0.5
        self.mu     = torch.from_numpy(self.mu).float()
        self.mu     = nn.Parameter(self.mu, requires_grad = False)

        # Sigma: covariance, initialised to the identity matrix.
        self.Sigma  = np.eye(latent_dimension, latent_dimension)
        self.Sigma  = torch.from_numpy(self.Sigma).float()
        self.Sigma  = nn.Parameter(self.Sigma, requires_grad = False)
        
        # Diagonal jitter added after every covariance update to keep Sigma
        # invertible.
        self.eps_Sigma  = torch.FloatTensor(np.diag([1.e-8 for _ in range(latent_dimension)]))

    def forward(self, est_inputs, with_log = True):
        """Evaluate the Phi-weighted Gaussian density for each row of est_inputs.

        Returns a 1-D tensor of densities (negative log-densities when
        with_log is True). NOTE(review): each value is converted to a Python
        float inside the loop, so no gradient flows through this forward.
        """
        batch_size, _   = est_inputs.shape
        out_values  = []
        inv_sigma   = torch.inverse(self.Sigma)
        # Determinant computed via numpy on a detached copy of Sigma.
        det_sigma   = np.linalg.det(self.Sigma.data.cpu().numpy())
        det_sigma   = torch.from_numpy(det_sigma.reshape([1])).float()
        det_sigma   = torch.autograd.Variable(det_sigma)
        for est_input in est_inputs:
            diff    = (est_input - self.mu).view(-1,1)
            # Mahalanobis quadratic form: -0.5 * diff^T Sigma^-1 diff.
            out     = -0.5 * torch.mm(torch.mm(diff.view(1,-1), inv_sigma), diff)
            # NOTE(review): the normaliser sqrt(2*pi*det) omits the usual
            # (2*pi)^(d/2) dimensionality exponent of a multivariate normal
            # density — confirm this is intentional.
            out     = (self.Phi * torch.exp(out)) / torch.sqrt(2. * np.pi * det_sigma)
            if with_log:
                out = -torch.log(out)
            out_values.append(float(out.data.cpu().numpy()))

        out = torch.autograd.Variable(torch.FloatTensor(out_values))
        return out
    
    def _update_parameters(self, samples, affiliations):
        """EM M-step: refresh Phi, mu and Sigma from soft-weighted samples.

        samples: (batch, latent_dimension) latent vectors.
        affiliations: (batch,) membership weights for this component.
        No-op unless the module is in training mode; all updates write to
        `.data` directly so autograd never records them.
        """
        if not self.training:
            return

        batch_size, _ = samples.shape

        # Updating phi.
        phi = torch.mean(affiliations)
        self.Phi.data = phi.data

        # Updating mu.
        num = 0.
        for i in range(batch_size):
            z_i     = samples[i, :]
            gamma_i = affiliations[i]
            num     += gamma_i * z_i
        
        denom        = torch.sum(affiliations)
        self.mu.data = (num / denom).data

        # Updating Sigma.
        mu  = self.mu
        num = None
        for i in range(batch_size):
            z_i     = samples[i, :]
            gamma_i = affiliations[i]
            diff    = (z_i - mu).view(-1, 1)
            # Weighted outer product gamma_i * diff diff^T.
            to_add  = gamma_i * torch.mm(diff, diff.view(1, -1))
            if num is None:
                num = to_add
            else:
                num += to_add

        denom           = torch.sum(affiliations)
        self.Sigma.data = (num / denom).data + self.eps_Sigma

class GMM(nn.Module):
    """A bag of `num_mixtures` Gaussian components over the latent space."""

    def __init__(self, num_mixtures, latent_dimension):
        super().__init__()
        self.num_mixtures = num_mixtures
        self.latent_dimension = latent_dimension
        self.mixtures = nn.ModuleList(
            Mixture(latent_dimension) for _ in range(num_mixtures))

    def forward(self, est_inputs):
        """Sample energy: -log sum_k Phi_k N(z; mu_k, Sigma_k) per row."""
        total = None
        for component in self.mixtures:
            likelihood = component(est_inputs, with_log=False)
            total = likelihood if total is None else total + likelihood
        return -torch.log(total)

    def _update_mixtures_parameters(self, samples, mixtures_affiliations):
        """Hand each component its own affiliation column (train mode only)."""
        if not self.training:
            return

        for idx, component in enumerate(self.mixtures):
            component._update_parameters(samples, mixtures_affiliations[:, idx])
### model
class DAGMM(nn.Module):
    """Deep Autoencoding Gaussian Mixture Model.

    Combines a compression network (autoencoder), an estimation network
    (soft mixture memberships) and a GMM over the latent representation.
    """

    def __init__(self, compression_module, estimation_module, gmm_module):
        super().__init__()

        self.compressor = compression_module
        self.estimator  = estimation_module
        self.gmm        = gmm_module

    def forward(self, input):
        """Return per-sample energies; refreshes GMM parameters in train mode."""
        encoded = self.compressor.encode(input)
        decoded = self.compressor.decode(encoded)

        relative_ed     = relative_euclidean_distance(input, decoded)
        cosine_sim      = cosine_similarity(input, decoded)

        relative_ed     = relative_ed.view(-1, 1)
        # BUG FIX: this line used to read `relative_ed.view(-1, 1)`, silently
        # duplicating the Euclidean feature and discarding cosine similarity.
        cosine_sim      = cosine_sim.view(-1, 1)
        # Latent representation = encoded code + two reconstruction features.
        latent_vectors  = torch.cat([encoded, relative_ed, cosine_sim], dim=1)

        if self.training:
            mixtures_affiliations = self.estimator(latent_vectors)
            self.gmm._update_mixtures_parameters(latent_vectors,
                                                 mixtures_affiliations)
        return self.gmm(latent_vectors)


class DAGMMArrhythmia(DAGMM):
    """DAGMM wired together from the three hidden-dimension lists."""

    def __init__(self, enc_hidden_dim_list, dec_hidden_dim_list, est_hidden_dim_list):
        compressor  = CompressionNet(enc_hidden_dim_list, dec_hidden_dim_list)
        estimator   = Estimation(est_hidden_dim_list)
        # GENERALISED: was the hard-coded GMM(num_mixtures=2, latent_dimension=4),
        # which silently broke for any other architecture. The GMM input is the
        # latent code plus the two reconstruction features, and the number of
        # mixtures equals the estimator's output width — identical values
        # (2 and 2+2=4) for the default dimension lists.
        gmm = GMM(num_mixtures=est_hidden_dim_list[-1],
                  latent_dimension=enc_hidden_dim_list[-1] + 2)

        super().__init__(compression_module = compressor,
                         estimation_module  = estimator,
                         gmm_module         = gmm)
### tests
def test_dagmm():
    """Smoke-test the full DAGMM on the loaded arrhythmia data."""
    model = DAGMMArrhythmia(args.enc_hidden_dim_list,
                            args.dec_hidden_dim_list,
                            args.est_hidden_dim_list)
    print(model(data_array))

def convert_to_var(input):
    """Wrap a numpy array as a float32 autograd Variable."""
    return torch.autograd.Variable(torch.from_numpy(input).float())

def test_update_mixture():
    """Print one Mixture's parameters before and after an EM-style update."""
    batch_size = 5
    latent_dimension = 7

    mix = Mixture(latent_dimension)
    latent_vectors = convert_to_var(np.random.random([batch_size, latent_dimension]))
    affiliations = convert_to_var(np.random.random([batch_size]))

    for param in mix.parameters():
        print(param)

    mix.train()
    mix._update_parameters(latent_vectors, affiliations)

    for param in mix.parameters():
        print(param)


def test_forward_mixture():
    """Run a forward pass through a single Mixture and print the result."""
    batch_size, latent_dimension = 5, 7

    mix = Mixture(latent_dimension)
    latent_vectors = convert_to_var(np.random.random([batch_size, latent_dimension]))

    mix.train()
    print(mix(latent_vectors))


def test_update_gmm():
    """Print GMM parameters before and after an update from random affiliations."""
    batch_size = 5
    latent_dimension = 7
    num_mixtures = 2

    gmm = GMM(num_mixtures, latent_dimension)

    latent_vectors = convert_to_var(np.random.random([batch_size, latent_dimension]))
    affiliations = convert_to_var(np.random.random([batch_size, num_mixtures]))

    for param in gmm.parameters():
        print(param)

    gmm.train()
    gmm._update_mixtures_parameters(latent_vectors, affiliations)

    for param in gmm.parameters():
        print(param)
if __name__ == '__main__':
    # Run every smoke test in sequence.
    for smoke_test in (test_update_mixture,
                       test_forward_mixture,
                       test_update_gmm,
                       test_dagmm):
        smoke_test()
Parameter containing:
tensor([0.6708])
Parameter containing:
tensor([ 0.0744, -0.3463,  0.1928, -0.0517,  0.2910,  0.0668,  1.1879])
Parameter containing:
tensor([[1., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 1.]])
Parameter containing:
tensor(0.5546)
Parameter containing:
tensor([0.1851, 0.2101, 0.4965, 0.5223, 0.4495, 0.7594, 0.3953])
Parameter containing:
tensor([[ 0.0318,  0.0050,  0.0162,  0.0379,  0.0040,  0.0040,  0.0056],
        [ 0.0050,  0.0393, -0.0210,  0.0103, -0.0568,  0.0220,  0.0469],
        [ 0.0162, -0.0210,  0.1608,  0.0731,  0.0986, -0.0601, -0.0862],
        [ 0.0379,  0.0103,  0.0731,  0.0693,  0.0219, -0.0098, -0.0135],
        [ 0.0040, -0.0568,  0.0986,  0.0219,  0.1173, -0.0607, -0.0939],
        [ 0.0040,  0.0220, -0.0601, -0.0098, -0.0607,  0.0441,  0.0458],
        [ 0.0056,  0.0469, -0.0862, -0.0135, -0.0939,  0.0458,  0.0832]])
tensor([1.8491, 2.2013, 2.1552, 1.5288, 1.8039])
Parameter containing:
tensor([0.2684])
Parameter containing:
tensor([ 0.8169, -0.3349,  0.6617,  0.7210,  1.1626,  0.4094,  0.2853])
Parameter containing:
tensor([[1., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 1.]])
Parameter containing:
tensor([0.9417])
Parameter containing:
tensor([ 0.1645,  0.1663, -0.3662,  1.1244, -0.4419, -0.3331,  0.1235])
Parameter containing:
tensor([[1., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 1.]])
Parameter containing:
tensor(0.6033)
Parameter containing:
tensor([0.4998, 0.3665, 0.3353, 0.4332, 0.5558, 0.7133, 0.5177])
Parameter containing:
tensor([[ 0.0387, -0.0469, -0.0093,  0.0224,  0.0045,  0.0055,  0.0054],
        [-0.0469,  0.0889,  0.0099, -0.0315, -0.0301, -0.0062, -0.0251],
        [-0.0093,  0.0099,  0.0118,  0.0068,  0.0017,  0.0017, -0.0159],
        [ 0.0224, -0.0315,  0.0068,  0.0309,  0.0081,  0.0077, -0.0151],
        [ 0.0045, -0.0301,  0.0017,  0.0081,  0.0198,  0.0008,  0.0121],
        [ 0.0055, -0.0062,  0.0017,  0.0077,  0.0008,  0.0020, -0.0049],
        [ 0.0054, -0.0251, -0.0159, -0.0151,  0.0121, -0.0049,  0.0373]])
Parameter containing:
tensor(0.5910)
Parameter containing:
tensor([0.4718, 0.5168, 0.3861, 0.3428, 0.4769, 0.6866, 0.4324])
Parameter containing:
tensor([[ 0.0259, -0.0333, -0.0069,  0.0161,  0.0043,  0.0040,  0.0049],
        [-0.0333,  0.0850,  0.0164, -0.0423, -0.0352, -0.0108, -0.0309],
        [-0.0069,  0.0164,  0.0149, -0.0047, -0.0039, -0.0021, -0.0186],
        [ 0.0161, -0.0423, -0.0047,  0.0356,  0.0181,  0.0098,  0.0045],
        [ 0.0043, -0.0352, -0.0039,  0.0181,  0.0221,  0.0044,  0.0155],
        [ 0.0040, -0.0108, -0.0021,  0.0098,  0.0044,  0.0029,  0.0016],
        [ 0.0049, -0.0309, -0.0186,  0.0045,  0.0155,  0.0016,  0.0334]])
tensor([-20.2235, -19.5796, -18.2748, -20.0366, -20.2120,  -3.3468, -20.2282,
        -19.9478, -20.1908, -20.2319, -20.1559, -20.0203, -19.9910, -20.0385,
        -19.3059, -20.0160, -19.0049, -15.0871, -20.1826, -20.1166, -17.8514,
        -19.2502, -20.0698, -20.2161, -20.1287, -20.0696, -18.3994, -20.1507,
        -15.5779, -12.5161, -19.8662, -19.4357, -19.5880, -20.1350, -20.0799,
        -20.0312, -20.1304, -20.2003, -20.0427, -18.5765, -19.6710, -19.9432,
        -19.9821, -19.2762, -20.0808, -17.4431, -20.2103, -19.9602, -20.1087,
        -19.9409, -20.0862, -19.5724, -20.2076, -17.5777, -20.0907, -19.2899,
        -19.6751, -19.9591, -20.2265, -19.8545, -19.7321, -20.0437, -19.8994,
        -20.1703, -18.1992, -20.1769, -19.8955, -19.7147, -19.9189, -19.6573,
        -19.6324, -19.9687, -19.9866, -20.2188, -20.1788, -19.6756, -15.5464,
        -20.0430, -18.4116, -19.9975, -19.6549, -19.7680, -19.9588, -20.0924,
        -20.0798, -17.8145, -20.2242, -19.5241, -18.2970, -20.1807, -17.3323,
        -19.6756, -19.9130, -19.7627, -19.0984, -16.9743, -18.3699, -19.6322,
        -19.8802, -13.0792, -19.8571, -17.5672, -11.8144, -19.0541, -20.1147,
        -19.9259, -10.1497, -20.1374, -13.6075, -19.9318, -17.3490, -20.1207,
        -20.2279, -18.4644, -20.2026, -20.2200, -18.1918, -17.3313, -20.1266,
        -20.1501, -19.5323, -18.5575, -19.6269, -19.7861, -19.8430, -19.2241,
        -20.1233, -20.2272, -19.5726, -20.0968, -17.0089, -20.1908, -20.2254,
        -15.1696, -20.1749, -20.0739, -19.8494, -19.4763, -20.1868, -20.0521,
        -20.1850, -20.0310, -19.9391, -19.6083, -19.8503, -20.1680, -18.7109,
        -20.0696, -19.8996, -20.2076, -20.1285, -20.1943, -19.5300, -19.8305,
        -19.9620, -19.5923, -20.2098, -20.2126, -19.1490, -19.7566, -20.0215,
        -20.2253, -20.2282, -20.2123, -20.0022, -20.2195, -19.7390, -20.0973,
        -20.2036, -20.2148, -17.6385, -19.1051, -20.0537, -19.6717, -18.1206,
        -13.3478, -20.2282, -20.0395, -19.0924, -16.0135, -19.9751, -20.1905,
        -18.7753, -18.6293, -18.7644, -20.0694, -20.2277, -16.1424, -19.3795,
        -12.5818, -18.5242, -20.0463, -20.1280, -19.6429, -19.7690, -19.9775,
        -20.2092, -19.8706, -11.7017, -19.5606, -20.2114, -19.5944,  -4.7972,
        -20.2233, -18.2674, -20.1300, -19.7335, -20.0928, -19.6151, -20.1276,
        -19.2420, -19.9885, -20.2277, -17.3811, -20.2235, -20.0476, -18.0080,
        -18.6992, -19.0452, -18.5626, -18.9380, -20.2170, -19.9416, -19.5714,
        -20.2123, -15.2158, -20.0069, -19.6128, -20.2222, -20.2066, -19.9658,
        -19.9399, -20.2157, -19.9143, -20.1029, -12.2647, -20.0939, -19.8378,
        -19.9069, -18.3312, -20.1544, -12.2160, -19.5462, -17.9280, -20.0103,
        -12.3935, -20.0941, -19.3605, -17.0206, -19.7714, -20.1615, -17.3467,
        -12.4248, -12.6264, -16.0423, -20.2009, -18.8140, -12.9780, -20.1700,
        -20.1144, -18.0521, -19.5850, -20.1609, -20.2211, -19.2474, -19.4934,
        -20.2263, -20.2281, -20.2141, -19.8640, -20.2276, -18.4989, -20.1412,
        -20.2073, -20.2119, -20.0940, -19.2913, -19.7819, -20.0905, -19.4009,
        -20.1439, -20.2230, -20.1651, -20.2034,  -3.1711, -14.8115, -20.1249,
        -20.2278, -19.2863, -19.2409, -20.1818, -18.9364, -19.7598, -19.9828,
        -19.8661, -19.4412, -14.2111, -19.1426, -18.7352, -20.1078, -10.2869,
        -19.4248, -19.9825, -14.9699, -19.1416, -19.9673, -20.1040, -18.8180,
        -17.7334, -20.2227, -18.6138, -19.6634, -19.5755, -20.2135, -20.1384,
        -19.1177, -17.2645, -19.8158, -20.1971, -19.5122, -16.0656, -19.9423,
        -17.3143, -14.3509, -19.8850, -19.9810, -19.0178, -11.9086, -18.4911,
        -17.3414, -19.4352, -19.6163, -19.8149, -19.8486, -20.0638, -20.1256,
        -18.8305, -20.0627, -20.1914, -20.1836, -19.9506, -20.1327, -20.1412,
        -20.1764, -17.9564, -20.2157, -19.9270, -19.8813, -14.8900, -18.4771,
        -15.0317, -16.0351, -16.3937, -19.8136, -17.3297, -20.1020, -15.2023,
        -19.8852, -20.1410, -15.1101, -20.0787, -15.4456, -18.9437, -20.2281,
        -20.1290, -19.0569, -11.3157, -19.6499, -17.0891, -11.9564, -20.1864,
        -18.3846, -20.2272, -20.0108, -19.9568, -20.1103, -12.0232, -18.8302,
        -20.1996, -19.6827, -17.4583, -13.0812, -20.1706, -19.1465, -20.2258,
        -19.6770, -11.9548, -19.2042, -12.0349, -19.1939, -19.4332, -18.3247,
        -20.2003, -18.4369, -20.0828, -14.0724, -20.2177, -20.1335, -20.2074,
        -20.0335, -19.9797, -12.7783, -19.4577, -14.7200, -19.6989, -20.0652,
        -19.8943, -19.4418, -19.9583, -18.4946, -19.9660, -20.1008, -19.9310,
        -20.1120, -20.0183, -19.1694, -19.2284, -10.0634, -20.1168, -20.0106,
         -4.8029, -20.0988, -20.2199, -19.4736, -12.5663, -20.2186, -19.4957,
        -20.1714, -17.5186, -20.1014, -11.8549, -20.2172, -19.9964, -12.3047,
        -20.1148, -20.0857, -20.2281, -19.9667, -20.0292, -20.1981, -20.0960,
        -19.1063, -20.0809, -19.7993, -20.0791, -19.0565, -20.1956, -17.5403,
        -20.1692, -17.5986, -20.1645, -19.9429])
C:\Users\UOS\anaconda3\Lib\site-packages\torch\nn\modules\container.py:217: UserWarning:

Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.

Ref

  • https://openreview.net/forum?id=BJJLHbb0-